{
 "cells": [
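  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to compute the backward pass of a layer normalization (LayerNorm) operation with cuDNN.\n",
    "\n",
    "$$\\text{LayerNorm}(x) = \\frac{x-\\mu}{\\sqrt{\\sigma^2 + \\epsilon}}\\cdot\\gamma+\\beta$$\n",
    "\n",
    "Here $\\mu$ and $\\sigma^2$ are the mean and variance of $x$ over the embedding dimension, $\\gamma$ is the scale, $\\beta$ is the bias, and $\\epsilon$ is a small constant that prevents division by zero."
   ]
  },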
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites and Setup\n",
    "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the cudnn python interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('pip install nvidia-cudnn-frontend')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup\n",
    "Create a cuDNN handle, a per-device handle used to initialize the cuDNN context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "import sys\n",
    "\n",
    "# Fail fast with a readable message: every later cell allocates tensors on a CUDA device.\n",
    "assert torch.cuda.is_available(), \"This notebook requires an NVIDIA GPU with CUDA support\"\n",
    "\n",
    "# Seed PyTorch's random number generator so the random tensors are reproducible.\n",
    "torch.manual_seed(1)\n",
    "\n",
    "# Handle passed to the cuDNN graphs created below; it is released in the last cell.\n",
    "handle = cudnn.create_handle()\n",
    "\n",
    "print(\"Running with cudnn backend version:\", cudnn.backend_version())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LayerNorm Training\n",
    "Problem sizes:\n",
    "- Batch Size: 4\n",
    "- Sequence Size: 1024\n",
    "- Embedding Dimension: 768"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch, seq_size, embedding_dim = 4, 1024, 768\n",
    "\n",
    "input_type = torch.float16\n",
    "\n",
    "# Epsilon is a small number to prevent division by 0.\n",
    "epsilon_value = 1e-3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create input tensor GPU buffers. We use PyTorch to allocate GPU tensors so we can reuse them easily when we calculate reference outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword arguments shared by all three random GPU inputs.\n",
    "gpu_input_kwargs = dict(dtype=input_type, requires_grad=True, device=\"cuda\")\n",
    "\n",
    "# Random 4-D inputs: x is (batch * seq_size, embedding_dim, 1, 1),\n",
    "# scale and bias are (1, embedding_dim, 1, 1). Each is requested in channels_last format.\n",
    "x_gpu = torch.randn((batch * seq_size, embedding_dim, 1, 1), **gpu_input_kwargs).to(\n",
    "    memory_format=torch.channels_last\n",
    ")\n",
    "scale_gpu = torch.randn((1, embedding_dim, 1, 1), **gpu_input_kwargs).to(\n",
    "    memory_format=torch.channels_last\n",
    ")\n",
    "bias_gpu = torch.randn((1, embedding_dim, 1, 1), **gpu_input_kwargs).to(\n",
    "    memory_format=torch.channels_last\n",
    ")\n",
    "\n",
    "# Epsilon is held in a one-element float32 tensor allocated on the CPU.\n",
    "epsilon_cpu = torch.full(\n",
    "    (1, 1, 1, 1),\n",
    "    epsilon_value,\n",
    "    dtype=torch.float32,\n",
    "    requires_grad=False,\n",
    "    device=\"cpu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create reference computation and output tensor GPU buffers using PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the PyTorch reference before the cuDNN computation so that the\n",
    "# output buffers can be allocated with .empty_like() on the reference tensors.\n",
    "out_expected = torch.nn.functional.layer_norm(\n",
    "    x_gpu,\n",
    "    [embedding_dim, 1, 1],\n",
    "    weight=scale_gpu.squeeze(0),\n",
    "    bias=bias_gpu.squeeze(0),\n",
    "    eps=epsilon_value,\n",
    ")\n",
    "\n",
    "# Reference statistics are computed in float32 over the normalized dims (1, 2, 3).\n",
    "x_fp32 = x_gpu.to(torch.float32)\n",
    "mean_expected = x_fp32.mean(dim=(1, 2, 3), keepdim=True)\n",
    "\n",
    "# LayerNorm normalizes with the biased (population) variance, while torch.var\n",
    "# defaults to the unbiased estimator, so request correction=0 explicitly.\n",
    "inv_var_expected = torch.rsqrt(\n",
    "    torch.var(x_fp32, dim=(1, 2, 3), correction=0, keepdim=True) + epsilon_value\n",
    ")\n",
    "\n",
    "# allocate output tensor memory using PyTorch\n",
    "out_gpu = torch.empty_like(out_expected)\n",
    "mean_gpu = torch.empty_like(mean_expected)\n",
    "inv_var_gpu = torch.empty_like(inv_var_expected)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create cuDNN Forward Graph and tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the cuDNN graph for the forward pass, bound to the handle created earlier,\n",
    "# with FLOAT as both the intermediate and the compute data type.\n",
    "fwd_graph = cudnn.pygraph(\n",
    "    handle=handle,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "# Create tensor handles with the fwd_graph, patterned after the PyTorch tensors.\n",
    "# The GPU tensors are detached first because they were created with requires_grad=True.\n",
    "x = fwd_graph.tensor_like(x_gpu.detach()).set_name(\"X\")\n",
    "scale = fwd_graph.tensor_like(scale_gpu.detach()).set_name(\"scale\")\n",
    "bias = fwd_graph.tensor_like(bias_gpu.detach()).set_name(\"bias\")\n",
    "epsilon = fwd_graph.tensor_like(epsilon_cpu).set_name(\"epsilon\")\n",
    "\n",
    "# Add a layernorm operation in the TRAINING phase. Besides the output it yields\n",
    "# mean and inv_var, which are fed to the backward graph later in this notebook.\n",
    "(out, mean, inv_var) = fwd_graph.layernorm(\n",
    "    name=\"layernorm\",\n",
    "    input=x,\n",
    "    norm_forward_phase=cudnn.norm_forward_phase.TRAINING,\n",
    "    scale=scale,\n",
    "    bias=bias,\n",
    "    epsilon=epsilon,\n",
    ")\n",
    "\n",
    "# Enable all outputs, giving each the dtype of its PyTorch reference tensor.\n",
    "# The trailing ';' keeps the notebook from displaying the last expression.\n",
    "out.set_name(\"output\").set_output(True).set_data_type(out_expected.dtype)\n",
    "mean.set_name(\"mean\").set_output(True).set_data_type(mean_expected.dtype)\n",
    "inv_var.set_name(\"inv_var\").set_output(True).set_data_type(inv_var_expected.dtype);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Validate and build the forward graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the fwd_graph, passing heuristic modes A and FALLBACK to build().\n",
    "fwd_graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "\n",
    "# To run this cell more than once, first re-run the previous cell to get a new fwd_graph.\n",
    "# The same instance of a fwd_graph should not be built twice."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the forward graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of mapping UIDs to memory (as in 20_layernorm.ipynb), we can directly map handles to memory. This is simpler but slightly slower to execute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch buffer for the forward graph, sized in bytes as reported by the graph.\n",
    "workspace = torch.empty(\n",
    "    fwd_graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8\n",
    ")\n",
    "\n",
    "# Mapping of (handles -> memory): graph inputs first ...\n",
    "variant_pack = {\n",
    "    x: x_gpu.detach(),\n",
    "    scale: scale_gpu.detach(),\n",
    "    bias: bias_gpu.detach(),\n",
    "    epsilon: epsilon_cpu,\n",
    "}\n",
    "# ... then the buffers that receive the graph outputs.\n",
    "variant_pack.update({out: out_gpu, mean: mean_gpu, inv_var: inv_var_gpu})\n",
    "\n",
    "fwd_graph.execute(variant_pack, workspace)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test cuDNN's output against PyTorch's and check correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the queued GPU work to finish before reading the results.\n",
    "torch.cuda.synchronize()\n",
    "\n",
    "# Each cuDNN output is compared with its PyTorch reference using the same tolerances.\n",
    "for actual, expected in (\n",
    "    (out_gpu, out_expected),\n",
    "    (inv_var_gpu, inv_var_expected),\n",
    "    (mean_gpu, mean_expected),\n",
    "):\n",
    "    torch.testing.assert_close(actual, expected, rtol=5e-3, atol=5e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### LayerNorm Backwards Pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute reference values for the backward graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch reference for the backward pass: an MSE loss against a random target.\n",
    "target = torch.randn_like(out_expected)\n",
    "criterion = torch.nn.MSELoss()\n",
    "loss = criterion(out_expected, target)\n",
    "\n",
    "# Ask autograd to keep .grad on every tensor whose gradient is read back later.\n",
    "for tensor in (out_expected, x_gpu, scale_gpu, bias_gpu):\n",
    "    tensor.retain_grad()\n",
    "\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build backwards graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a separate cuDNN graph for the backward pass, using the same handle and\n",
    "# the same FLOAT intermediate/compute data types as the forward graph.\n",
    "bwd_graph = cudnn.pygraph(\n",
    "    handle=handle,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "# Create tensors associated with the backwards graph. DO NOT reuse tensor handles from the forward graph.\n",
    "# d_out (the incoming gradient) is declared with the size, stride and dtype of x_gpu.\n",
    "d_out = bwd_graph.tensor(\n",
    "    name=\"d_out\", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype\n",
    ")\n",
    "\n",
    "# The remaining inputs are new backward-graph tensors patterned after the forward-graph handles.\n",
    "x_bwd = bwd_graph.tensor_like(x, name=\"x\")\n",
    "scale_bwd = bwd_graph.tensor_like(scale, name=\"scale\")\n",
    "mean_bwd = bwd_graph.tensor_like(mean, name=\"mean\")\n",
    "inv_var_bwd = bwd_graph.tensor_like(inv_var, name=\"inv_var\")\n",
    "\n",
    "# Add the layernorm backwards operation; it yields the gradients for x, scale and bias.\n",
    "(d_x, d_scale, d_bias) = bwd_graph.layernorm_backward(\n",
    "    name=\"d_layernorm\",\n",
    "    grad=d_out,\n",
    "    input=x_bwd,\n",
    "    scale=scale_bwd,\n",
    "    mean=mean_bwd,\n",
    "    inv_variance=inv_var_bwd,\n",
    ")\n",
    "\n",
    "# Enable outputs. All three gradients use the dtype of x_gpu.\n",
    "d_x.set_output(True).set_data_type(x_gpu.dtype)\n",
    "d_scale.set_output(True).set_data_type(x_gpu.dtype)\n",
    "d_bias.set_output(True).set_data_type(x_gpu.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the bwd_graph, passing heuristic modes A and FALLBACK to build().\n",
    "bwd_graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the backward graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create output buffers for gradients, shaped like the tensors they are gradients of.\n",
    "d_x_gpu = torch.empty_like(x_gpu)\n",
    "d_scale_gpu = torch.empty_like(scale_gpu)\n",
    "d_bias_gpu = torch.empty_like(bias_gpu)\n",
    "\n",
    "# Scratch buffer for the backward graph, sized in bytes as reported by the graph.\n",
    "workspace = torch.empty(\n",
    "    bwd_graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8\n",
    ")\n",
    "\n",
    "# Inputs of the backward graph:\n",
    "#   - x_bwd and scale_bwd receive the same input tensors the forward graph used;\n",
    "#   - mean_bwd and inv_var_bwd receive the statistics written by the forward graph;\n",
    "#   - d_out receives the gradient that PyTorch's autograd stored in out_expected.grad.\n",
    "# d_x, d_scale and d_bias are mapped to the gradient output buffers allocated above.\n",
    "bwd_graph.execute(\n",
    "    {\n",
    "        x_bwd: x_gpu.detach(),\n",
    "        scale_bwd: scale_gpu.detach(),\n",
    "        d_out: out_expected.grad,\n",
    "        mean_bwd: mean_gpu.detach(),\n",
    "        inv_var_bwd: inv_var_gpu.detach(),\n",
    "        d_x: d_x_gpu,\n",
    "        d_scale: d_scale_gpu,\n",
    "        d_bias: d_bias_gpu,\n",
    "    },\n",
    "    workspace,\n",
    "    handle=handle,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare results and check correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the queued GPU work to finish before reading the gradients.\n",
    "torch.cuda.synchronize()\n",
    "\n",
    "# Compare the cuDNN gradients (actual) with the PyTorch autograd reference (expected).\n",
    "# assert_close takes (actual, expected), the same order the forward checks use, so\n",
    "# failure messages and the relative tolerance refer to the reference values.\n",
    "torch.testing.assert_close(d_x_gpu, x_gpu.grad, atol=2e-4, rtol=2e-4)\n",
    "torch.testing.assert_close(d_scale_gpu, scale_gpu.grad, atol=2e-4, rtol=2e-4)\n",
    "torch.testing.assert_close(d_bias_gpu, bias_gpu.grad, atol=2e-4, rtol=2e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perform Cleanup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release the cuDNN handle created in the setup cell. Both graphs were created\n",
    "# with this handle, so this is done last, after all graph executions.\n",
    "cudnn.destroy_handle(handle)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
